# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Implementation of the tfx component functions for the coco captions example."""

import tempfile

import tensorflow as tf
import tensorflow_transform as tft
import tensorflow_transform.beam as tft_beam
from tfx import v1 as tfx


# [START tfx_run_fn]
def run_fn(fn_args: tfx.components.FnArgs) -> None:
  """Build the TF model, train it and export it."""
  # Define a minimal Keras model: a single dense unit over a 10-feature input.
  model = tf.keras.Sequential(
      [tf.keras.layers.Dense(1, input_dim=10)])
  model.compile()

  # train the model on the preprocessed data
  # model.fit(...)

  # Export the model to the directory TFX serves trained models from.
  model.save(fn_args.serving_model_dir)


# [END tfx_run_fn]


# [START tfx_preprocessing_fn]
def preprocessing_fn(inputs):
  """Transform raw data.

  Args:
    inputs: dict of input tensors keyed by feature name; must contain a
      'caption' string tensor.

  Returns:
    dict with the lowercased captions under 'caption_lower'.
  """
  # Lowercase the captions (an instance-level, row-wise operation).
  lower = tf.strings.lower(inputs['caption'])

  # Compute the mean caption length with a full pass over the dataset
  # (tft.mean is a full-pass analyzer). A real pipeline would use such
  # full-pass statistics in further preprocessing; here the value is a
  # placeholder and is intentionally unused.
  mean_length = tft.mean(tf.strings.length(lower))
  # <do some preprocessing with the mean length>

  return {
      'caption_lower': lower,
  }


# [END tfx_preprocessing_fn]

# [START tfx_analyze_and_transform]
if __name__ == "__main__":
  # Exercise preprocessing_fn directly, outside of a tfx pipeline.
  captions = [
      "A bicycle replica with a clock as the front wheel.",
      "A black Honda motorcycle parked in front of a garage.",
      "A room with blue walls and a white sink and door.",
  ]
  raw_data = [{"caption": text} for text in captions]

  # define the feature_spec (in a tfx pipeline this would be generated by a SchemaGen component)
  feature_spec = {"caption": tf.io.FixedLenFeature([], tf.string)}
  raw_data_metadata = tft.DatasetMetadata.from_feature_spec(feature_spec)

  # Run the Beam implementation of AnalyzeAndTransformDataset over the
  # in-memory dataset to test out the preprocessing_fn.
  with tft_beam.Context(temp_dir=tempfile.mkdtemp()):
    transformed_dataset, transform_fn = (
        (raw_data, raw_data_metadata)
        | tft_beam.AnalyzeAndTransformDataset(preprocessing_fn))
  transformed_data, transformed_metadata = transformed_dataset
# [END tfx_analyze_and_transform]
